import unicodedata
from typing import Iterable, Iterator, List

import opencc
from torch.utils.data import Dataset
from torchtext.data import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

from config import *

# OpenCC converter: Traditional Chinese -> Simplified Chinese ('t2s' profile).
cc = opencc.OpenCC('t2s')


# all_letters = string.ascii_letters + " .,;'"
# To simplify data processing, convert Unicode strings to plain ASCII.
def unicodeToAscii(s):
    """Strip diacritics: NFD-decompose *s* and drop combining marks ('Mn')."""
    decomposed = unicodedata.normalize('NFD', s)
    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept)


def filterPair(p, max_length=None):
    """Return True when both sides of a sentence pair are under the length cap.

    Args:
        p: two-element sequence of token lists — p[0] English, p[1] Chinese.
        max_length: optional length cap; defaults to MAX_LENGTH from config.

    Note: the original ended with a backslash continuation into a comment
    line, which would become a SyntaxError if that comment were deleted.
    The prefix filter it referenced is kept here for the record:
    # and p[0].startswith(eng_prefixes)
    """
    if max_length is None:
        max_length = MAX_LENGTH
    return len(p[0]) < max_length and len(p[1]) < max_length


# eng_prefixes = (
#     "i am ", "i'm ",
#     "he is", "he's ",
#     "she is", "she's ",
#     "you are", "you're ",
#     "we are", "we're ",
#     "they are", "they're "
# )
# spaCy tokenizers for each language (require the en_core_web_sm and
# zh_core_web_sm models to be installed).
token_transform_en = get_tokenizer('spacy', language='en_core_web_sm')
token_transform_zh = get_tokenizer('spacy', language='zh_core_web_sm')


class MyData(Dataset):
    """Parallel English-Chinese sentence pairs from a tab-separated file.

    ``self.pairs`` is a list of ``[eng_tokens, zh_tokens]`` entries, i.e.
    pairs = [["eng0", "Chi0"], ["eng1", "Chi1"], ...] where each side is a
    token list produced by the spaCy tokenizers.
    """

    def __init__(self, path=DATA_ROOT + 'eng-cmn.txt'):  # constructor
        self.pairs = []
        # 'with' closes the file handle (the original leaked it).
        with open(path, encoding='utf-8') as f:
            for line in f:
                # Traditional -> Simplified Chinese, then split the two columns.
                pair = cc.convert(line).split('\t')
                pair[0] = token_transform_en(unicodeToAscii(pair[0].lower().strip()))
                pair[1] = token_transform_zh(unicodeToAscii(pair[1].strip()))
                if filterPair(pair):
                    self.pairs.append(pair)

    def __getitem__(self, idx):  # fetch one sample
        """Return the (english_tokens, chinese_tokens) tuple at *idx*."""
        return self.pairs[idx][0], self.pairs[idx][1]

    def __len__(self):
        """Number of filtered pairs.

        The original declared this as ``__add__(self, other)`` — clearly a
        typo: it ignored ``other`` and returned the pair count, and a Dataset
        needs ``__len__`` for ``len()`` / DataLoader to work.
        """
        return len(self.pairs)


def yield_tokens(data_iter: Iterable, id: int) -> Iterator[List[str]]:
    """Yield the token list in column *id* of each sample in *data_iter*.

    Fix: the original annotated the return as ``List[str]``, but this is a
    generator yielding one token list per sample, so the accurate type is
    ``Iterator[List[str]]``.
    """
    for data_sample in data_iter:
        yield data_sample[id]


def vocab_generate():
    """Build the English and Chinese vocabularies from the eng-cmn dataset."""
    dataset = MyData()
    vocabs = [
        build_vocab_from_iterator(yield_tokens(dataset, column),
                                  specials=special_symbols)
        for column in (0, 1)
    ]
    return vocabs[0], vocabs[1]


if __name__ == "__main__":
    # Build both vocabularies once and persist them next to the raw data.
    # NOTE(review): ``torch`` is not imported here directly — it presumably
    # arrives via ``from config import *``; confirm config exports it.
    eng, zh = vocab_generate()
    torch.save(eng, DATA_ROOT + "eng_vocab")
    torch.save(zh, DATA_ROOT + "zh_vocab")
